{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit ('ndd': virtualenv)",
   "metadata": {
    "interpreter": {
     "hash": "855c3730de18900ddca5ab1525c8a72e412e6f1abfd1a86b3ec4c3008017ba5b"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# Matrixplot and Adjplot: Visualize and sort matrices with metadata\n",
    "\n",
    "This guide introduces `matrixplot` and `adjplot`. Both functions sort a matrix according to some metadata and plot it as either a heatmap or a scattermap. They can also add color or tick axes to indicate the separation between different groups or attributes.\n",
    "\n",
    "Note: `matrixplot` and `adjplot` have almost identical inputs and functionality. `adjplot` is a convenience wrapper around `matrixplot` that assumes the matrix to be plotted is square and has the same row and column metadata."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "from graspologic.simulations import sbm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from graspologic.plot import adjplot, matrixplot\n",
    "import seaborn as sns\n"
   ],
   "cell_type": "code",
   "metadata": {},
   "outputs": []
  },
  {
   "source": [
    "## Simulate a binary graph using a stochastic block model\n",
    "The 4-block model is defined as follows:\n",
    "\n",
    "\\begin{align*}\n",
    "n &= [50, 50, 50, 50]\\\\\n",
    "P &= \n",
    "\\begin{bmatrix}0.8 & 0.1 & 0.05 & 0.01\\\\\n",
    "0.1 & 0.4 & 0.15 & 0.02\\\\\n",
    "0.05 & 0.15 & 0.3 & 0.01\\\\\n",
    "0.01 & 0.02 & 0.01 & 0.4\n",
    "\\end{bmatrix}\n",
    "\\end{align*}\n",
    "\n",
    "Thus, the first 50 vertices belong to block 1, the second 50 vertices belong to block 2, the third 50 vertices belong to block 3, and the last 50 vertices belong to block 4.\n",
    "\n",
    "Each block is associated with some metadata:\n",
    "\n",
    "| Block Number | Hemisphere | Region |\n",
    "|:------------:|:----------:|:----:|\n",
    "|       1      |      0     |   0  |\n",
    "|       2      |      0     |   1  |\n",
    "|       3      |      1     |   0  |\n",
    "|       4      |      1     |   1  |\n",
    "\n",
    "\n"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 50\n",
    "n_communities = [N, N, N, N]\n",
    "p = [[0.8, 0.1, 0.05, 0.01],\n",
    "     [0.1, 0.4, 0.15, 0.02],\n",
    "     [0.05, 0.15, 0.3, 0.01],\n",
    "     [0.01, 0.02, 0.01, 0.4]]\n",
    "\n",
    "np.random.seed(2)\n",
    "A = sbm(n_communities, p)\n",
    "# One metadata row per vertex, in block order (blocks 1-4, N vertices each)\n",
    "meta = pd.DataFrame(\n",
    "    data={\n",
    "        'hemisphere': np.repeat([0, 1], 2*N),  # blocks 1-2 -> 0, blocks 3-4 -> 1\n",
    "        'region': np.tile(np.repeat([0, 1], N), 2),  # blocks 1, 3 -> 0; blocks 2, 4 -> 1\n",
    "        'cell_size': np.arange(4*N)},\n",
    ")"
   ]
  },
  {
   "source": [
    "Before any shuffling, the vertices are ordered by block, and the data looks like this:"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    ")"
   ]
  },
  {
   "source": [
    "Shuffle the vertices, applying the same permutation to the matrix and the metadata, so we can see why sorting the matrix matters visually:"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "rnd_idx = np.arange(4*N)\n",
    "np.random.shuffle(rnd_idx)\n",
    "A = A[np.ix_(rnd_idx, rnd_idx)]\n",
    "meta = meta.reindex(rnd_idx)"
   ]
  },
  {
   "source": [
    "The data immediately after randomization:"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    ")"
   ]
  },
  {
   "source": [
    "## The use of `group`\n",
    "\n",
    "The `group` parameter can be a list of strings (column names in `meta`) or a NumPy array by which to group the matrix."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### Group the matrix by one metadata field (`hemisphere`)"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\"],\n",
    "    sizes=(5, 5),\n",
    ")"
   ]
  },
  {
   "source": [
    "### Group the matrix by another metadata field (`region`), with a color axis to label the hemisphere"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"region\"],\n",
    "    color=[\"hemisphere\"],\n",
    "    sizes=(5, 5),\n",
    ")"
   ]
  },
  {
   "source": [
    "### Group by two metadata fields at the same time\n",
    "\n",
    "The order of the list matters: the matrix is grouped by the first field, and then by the second field within each group of the first."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"region\", \"hemisphere\"],\n",
    ")"
   ]
  },
  {
   "source": [
    "## The use of `group_order`"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### Sort the groups by their sizes\n",
    "\n",
    "If the groups have different sizes, we can sort them by size in ascending order."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.random import normal\n",
    "N = 10\n",
    "n_communities = [N, 3*N, 2*N, N]\n",
    "p = [[0.8, 0.1, 0.05, 0.01],\n",
    "     [0.1, 0.4, 0.15, 0.02],\n",
    "     [0.05, 0.15, 0.3, 0.01],\n",
    "     [0.01, 0.02, 0.01, 0.4]]\n",
    "# Edge weights: every block pair draws from a normal distribution with mean 5 and std 1\n",
    "wt = [[normal]*4]*4\n",
    "wtargs = [[dict(loc=5, scale=1)]*4]*4\n",
    "\n",
    "np.random.seed(2)\n",
    "A = sbm(n_communities, p, wt=wt, wtargs=wtargs)\n",
    "meta = pd.DataFrame(\n",
    "    data={\n",
    "        'hemisphere': np.repeat([0, 1], [4*N, 3*N]),\n",
    "        'region': np.repeat([0, 1, 0, 1], n_communities),\n",
    "        'cell_size': np.arange(7*N),\n",
    "        # axon length is drawn per block: mean 5 for region 0, mean 2 for region 1\n",
    "        'axon_length': np.concatenate([np.random.normal(loc, 1, n) for loc, n in zip([5, 2, 5, 2], n_communities)])},\n",
    ")\n",
    "rnd_idx = np.arange(7*N)\n",
    "np.random.shuffle(rnd_idx)\n",
    "A = A[np.ix_(rnd_idx, rnd_idx)]\n",
    "meta = meta.reindex(rnd_idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    "    group_order=\"size\",  # special keyword for group size; it is not a column in `meta`\n",
    "    color=[\"cell_size\", \"axon_length\"],\n",
    "    palette=[\"Purples\", \"Blues\"],\n",
    "    sizes=(1, 30),\n",
    ")"
   ]
  },
  {
   "source": [
    "If the metadata has other fields, we can also sort the groups by the mean of one of those fields, in ascending order."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    "    group_order=[\"cell_size\"],\n",
    "    color=[\"cell_size\", \"axon_length\"],\n",
    "    palette=[\"Purples\", \"Blues\"],\n",
    ")"
   ]
  },
  {
   "source": [
    "We can also sort by multiple fields at the same time. The list may also include the special `size` keyword to take group size into account."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    "    group_order=[\"cell_size\", \"axon_length\"],\n",
    "    color=[\"cell_size\", \"axon_length\"],\n",
    "    palette=[\"Purples\", \"Blues\"],\n",
    ")"
   ]
  },
  {
   "source": [
    "## The use of `item_order`\n",
    "\n",
    "The `item_order` parameter sorts the items within each group."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "Without `item_order`, the items remain in random order within each group:"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    "    color=[\"cell_size\"],\n",
    "    palette=\"Purples\",\n",
    ")"
   ]
  },
  {
   "source": [
    "With `item_order`, the items are sorted within each group:"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    "    item_order=[\"cell_size\"],\n",
    "    color=[\"cell_size\"],\n",
    "    palette=\"Purples\",\n",
    ")"
   ]
  },
  {
   "source": [
    "## The use of `highlight`\n",
    "`highlight` can be used to draw the separators of a particular grouping field in a different style, set through `highlight_kws`."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "highlight_kws = dict(color=\"red\", linestyle=\"-\", linewidth=5)\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    "    highlight=[\"hemisphere\"],\n",
    "    highlight_kws=highlight_kws,\n",
    ")"
   ]
  },
  {
   "source": [
    "## The use of multiple palettes\n",
    "All color axes can share the same `palette`, or a different palette can be specified for each field in `color`."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=[\"hemisphere\", \"region\"],\n",
    "    ticks=False,\n",
    "    item_order=[\"cell_size\"],\n",
    "    color=[\"hemisphere\", \"region\", \"cell_size\", \"axon_length\"],\n",
    "    palette=[\"tab10\", \"tab20\", \"Purples\", \"Blues\"],\n",
    ")"
   ]
  },
  {
   "source": [
    "## Label the row and column axes with different metadata\n",
    "If you would like to group the rows and columns by different metadata, you can use the `matrixplot` function to specify the parameters for each axis separately. Most arguments are the same as for `adjplot`, with a `row_` or `col_` prefix to specify the corresponding axis of `data`.\n",
    "\n",
    "Note: for adjacency matrices, one should not sort the rows and columns separately, as it breaks the representation of the graph. Here we do so just for demonstration, assuming `A` is not an adjacency matrix."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "matrixplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    col_meta=meta,\n",
    "    row_meta=meta,\n",
    "    plot_type=\"scattermap\",\n",
    "    col_group=[\"hemisphere\"],\n",
    "    row_group=[\"region\"],\n",
    "    col_item_order=[\"cell_size\"],\n",
    "    row_item_order=[\"cell_size\"],\n",
    "    row_color=[\"region\"],\n",
    "    row_ticks=False,\n",
    ")"
   ]
  },
  {
   "source": [
    "## Plot using `heatmap` instead of `scattermap`"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "matrixplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    col_meta=meta,\n",
    "    row_meta=meta,\n",
    "    plot_type=\"heatmap\",\n",
    "    col_group=[\"hemisphere\"],\n",
    "    row_group=[\"region\"],\n",
    "    col_item_order=[\"cell_size\"],\n",
    "    row_item_order=[\"cell_size\"],\n",
    "    row_color=[\"region\"],\n",
    "    row_ticks=False,\n",
    ")"
   ]
  },
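  {
   "source": [
    "## Supply array-likes instead of `meta`\n",
    "If `meta` is not provided, each of the sorting, grouping, and coloring keywords can be supplied with array-like data structures. Here `group` is a two-column array with one row per vertex, shuffled with the same permutation as `A`."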
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same block labels as the `hemisphere` and `region` columns of `meta`, one column each\n",
    "group_0 = np.repeat([0, 1], [4*N, 3*N])\n",
    "group_1 = np.repeat([0, 1, 0, 1], [N, 3*N, 2*N, N])\n",
    "group = np.column_stack((group_0, group_1))\n",
    "\n",
    "# Apply the same permutation that was used to shuffle `A`\n",
    "group = group[rnd_idx, :]\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "adjplot(\n",
    "    data=A,\n",
    "    ax=ax,\n",
    "    plot_type=\"scattermap\",\n",
    "    group=group,\n",
    "    sizes=(1, 30),\n",
    ")"
   ]
  }
 ]
}